{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "(tut_one)=\n",
    "\n",
    "# **Full Network Quantization**  \n",
    "\n",
    "In this tutorial, we will take a sample ResNet-like network and perform ``full`` network quantization.\n",
    "\n",
    "\n",
    "```{eval-rst}\n",
    "\n",
    ".. admonition:: Goal\n",
    "    :class: note\n",
    "\n",
    "    #. Take a ResNet-like model and train it on the CIFAR10 dataset.\n",
    "    #. Perform full model quantization.\n",
    "    #. Fine-tune to recover model accuracy.\n",
    "    #. Save both the original and the quantized model, converting each to ONNX.\n",
    "\n",
    "```\n",
    "---"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [],
   "source": [
    "#\n",
    "# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n",
    "# SPDX-License-Identifier: Apache-2.0\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "#\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow_quantization import quantize_model\n",
    "import tiny_resnet\n",
    "from tensorflow_quantization import utils\n",
    "import os\n",
    "\n",
    "tf.keras.backend.clear_session()\n",
    "\n",
    "# All TF and ONNX artifacts of this tutorial are stored under <cwd>/tutorials\n",
    "tutorials_root = os.path.join(os.getcwd(), \"tutorials\")\n",
    "assets = utils.CreateAssetsFolders(tutorials_root)\n",
    "assets.add_folder(\"simple_network_quantize_full\")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "# Load the CIFAR10 train/test splits\n",
    "cifar10 = tf.keras.datasets.cifar10\n",
    "(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()\n",
    "\n",
    "# Scale both image sets by 1/255 so that each pixel value lies between 0 and 1.\n",
    "train_images, test_images = train_images / 255.0, test_images / 255.0"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model/model_to_dot to work.\n"
     ]
    }
   ],
   "source": [
    "nn_model_original = tiny_resnet.model()\n",
    "\n",
    "# Build the image path with os.path.join instead of hard-coding a \"/\" separator.\n",
    "fp32_plot_path = os.path.join(assets.simple_network_quantize_full.fp32, \"model.png\")\n",
    "tf.keras.utils.plot_model(nn_model_original, to_file=fp32_plot_path)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "1407/1407 [==============================] - 16s 9ms/step - loss: 1.7653 - accuracy: 0.3622 - val_loss: 1.5516 - val_accuracy: 0.4552\n",
      "Epoch 2/10\n",
      "1407/1407 [==============================] - 13s 9ms/step - loss: 1.4578 - accuracy: 0.4783 - val_loss: 1.3877 - val_accuracy: 0.5042\n",
      "Epoch 3/10\n",
      "1407/1407 [==============================] - 13s 9ms/step - loss: 1.3499 - accuracy: 0.5193 - val_loss: 1.3066 - val_accuracy: 0.5342\n",
      "Epoch 4/10\n",
      "1407/1407 [==============================] - 13s 9ms/step - loss: 1.2736 - accuracy: 0.5486 - val_loss: 1.2636 - val_accuracy: 0.5550\n",
      "Epoch 5/10\n",
      "1407/1407 [==============================] - 13s 9ms/step - loss: 1.2101 - accuracy: 0.5732 - val_loss: 1.2121 - val_accuracy: 0.5670\n",
      "Epoch 6/10\n",
      "1407/1407 [==============================] - 12s 9ms/step - loss: 1.1559 - accuracy: 0.5946 - val_loss: 1.1753 - val_accuracy: 0.5844\n",
      "Epoch 7/10\n",
      "1407/1407 [==============================] - 13s 9ms/step - loss: 1.1079 - accuracy: 0.6101 - val_loss: 1.1143 - val_accuracy: 0.6076\n",
      "Epoch 8/10\n",
      "1407/1407 [==============================] - 13s 9ms/step - loss: 1.0660 - accuracy: 0.6272 - val_loss: 1.0965 - val_accuracy: 0.6158\n",
      "Epoch 9/10\n",
      "1407/1407 [==============================] - 13s 9ms/step - loss: 1.0271 - accuracy: 0.6392 - val_loss: 1.1100 - val_accuracy: 0.6122\n",
      "Epoch 10/10\n",
      "1407/1407 [==============================] - 13s 9ms/step - loss: 0.9936 - accuracy: 0.6514 - val_loss: 1.0646 - val_accuracy: 0.6304\n",
      "Baseline FP32 model test accuracy: 61.65\n"
     ]
    }
   ],
   "source": [
    "# Compile the original (FP32) classification model, then train it\n",
    "nn_model_original.compile(\n",
    "    optimizer=tiny_resnet.optimizer(lr=1e-4),\n",
    "    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "    metrics=[\"accuracy\"],\n",
    ")\n",
    "nn_model_original.fit(\n",
    "    train_images,\n",
    "    train_labels,\n",
    "    batch_size=32,\n",
    "    epochs=10,\n",
    "    validation_split=0.1,\n",
    ")\n",
    "\n",
    "# Measure the FP32 baseline on the test set and express it as a percentage\n",
    "_, baseline_model_accuracy = nn_model_original.evaluate(test_images, test_labels, verbose=0)\n",
    "baseline_model_accuracy = round(baseline_model_accuracy * 100, 2)\n",
    "print(\"Baseline FP32 model test accuracy: {}\".format(baseline_model_accuracy))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: /home/nvidia/PycharmProjects/tensorrt_qat/docs/source/notebooks/tutorials/simple_network_quantize_full/fp32/saved_model/assets\n",
      "WARNING:tensorflow:From /home/nvidia/PycharmProjects/tensorrt_qat/venv38_tf2.8_newPR/lib/python3.8/site-packages/tf2onnx/tf_loader.py:711: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.graph_util.extract_sub_graph`\n",
      "ONNX conversion Done!\n"
     ]
    }
   ],
   "source": [
    "tutorial_assets = assets.simple_network_quantize_full\n",
    "\n",
    "# Persist the trained FP32 model, then export that saved model to ONNX\n",
    "tf.keras.models.save_model(nn_model_original, tutorial_assets.fp32_saved_model)\n",
    "utils.convert_saved_model_to_onnx(\n",
    "    saved_model_dir=tutorial_assets.fp32_saved_model,\n",
    "    onnx_model_path=tutorial_assets.fp32_onnx_model,\n",
    ")\n",
    "\n",
    "# Quantize the full model and compile it with the same optimizer, loss and metric\n",
    "q_nn_model = quantize_model(model=nn_model_original)\n",
    "q_nn_model.compile(\n",
    "    optimizer=tiny_resnet.optimizer(lr=1e-4),\n",
    "    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "    metrics=[\"accuracy\"],\n",
    ")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy immediately after quantization:50.45, diff:11.199999999999996\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the quantized model on the test set before any fine-tuning\n",
    "_, q_model_accuracy = q_nn_model.evaluate(test_images, test_labels, verbose=0)\n",
    "q_model_accuracy = round(100 * q_model_accuracy, 2)\n",
    "\n",
    "# Round the difference as well: subtracting two rounded floats can still leave\n",
    "# binary floating-point noise (e.g. 11.199999999999996 instead of 11.2).\n",
    "accuracy_drop = round(baseline_model_accuracy - q_model_accuracy, 2)\n",
    "\n",
    "print(\n",
    "    \"Test accuracy immediately after quantization: {}, diff: {}\".format(\n",
    "        q_model_accuracy, accuracy_drop\n",
    "    )\n",
    ")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model/model_to_dot to work.\n"
     ]
    }
   ],
   "source": [
    "# Build the image path with os.path.join instead of hard-coding a \"/\" separator.\n",
    "int8_plot_path = os.path.join(assets.simple_network_quantize_full.int8, \"model.png\")\n",
    "tf.keras.utils.plot_model(q_nn_model, to_file=int8_plot_path)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2\n",
      "1407/1407 [==============================] - 27s 19ms/step - loss: 0.9625 - accuracy: 0.6630 - val_loss: 1.0430 - val_accuracy: 0.6420\n",
      "Epoch 2/2\n",
      "1407/1407 [==============================] - 25s 18ms/step - loss: 0.9315 - accuracy: 0.6758 - val_loss: 1.0717 - val_accuracy: 0.6336\n",
      "Accuracy after fine-tuning for 2 epochs: 62.27\n",
      "Baseline accuracy (for reference): 61.65\n"
     ]
    }
   ],
   "source": [
    "# Number of epochs used to fine-tune the quantized model\n",
    "fine_tune_epochs = 2\n",
    "\n",
    "q_nn_model.fit(\n",
    "    train_images, train_labels, batch_size=32, epochs=fine_tune_epochs, validation_split=0.1\n",
    ")\n",
    "\n",
    "# Re-evaluate on the test set and report the result next to the FP32 baseline\n",
    "_, q_model_accuracy = q_nn_model.evaluate(test_images, test_labels, verbose=0)\n",
    "q_model_accuracy = round(q_model_accuracy * 100, 2)\n",
    "fine_tuned_report = \"Accuracy after fine-tuning for {} epochs: {}\".format(\n",
    "    fine_tune_epochs, q_model_accuracy\n",
    ")\n",
    "print(fine_tuned_report)\n",
    "print(\"Baseline accuracy (for reference): {}\".format(baseline_model_accuracy))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:Found untraced functions such as conv2d_layer_call_fn, conv2d_layer_call_and_return_conditional_losses, conv2d_1_layer_call_fn, conv2d_1_layer_call_and_return_conditional_losses, conv2d_2_layer_call_fn while saving (showing 5 of 18). These functions will not be directly callable after loading.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: /home/nvidia/PycharmProjects/tensorrt_qat/docs/source/notebooks/tutorials/simple_network_quantize_full/int8/saved_model/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: /home/nvidia/PycharmProjects/tensorrt_qat/docs/source/notebooks/tutorials/simple_network_quantize_full/int8/saved_model/assets\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ONNX conversion Done!\n"
     ]
    }
   ],
   "source": [
    "# Save TF INT8 quantized model\n",
    "tf.keras.models.save_model(q_nn_model, assets.simple_network_quantize_full.int8_saved_model)\n",
    "\n",
    "# Convert INT8 model to ONNX\n",
    "utils.convert_saved_model_to_onnx(saved_model_dir = assets.simple_network_quantize_full.int8_saved_model, onnx_model_path = assets.simple_network_quantize_full.int8_onnx_model)\n",
    "\n",
    "tf.keras.backend.clear_session()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "```{note}\n",
    "ONNX files can be visualized with [Netron](https://netron.app/).\n",
    "```"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "4442e1c252d743d7d1ab28567e302ebe8a15da81acb5d7e7894db75e10bdb29d"
  },
  "kernelspec": {
   "display_name": "Python 3.8.10 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}